package io.quarkus.websockets.next.test.broadcast;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.vertx.core.Vertx;
import io.vertx.core.http.WebSocket;
import io.vertx.core.http.WebSocketClient;

/**
 * Verifies that the message produced by an endpoint's open callback is broadcast to every connected client.
 * <p>
 * Each endpoint ({@code Lo}, {@code LoBlocking}, {@code LoMultiProduce}) is exercised by connecting two clients in
 * sequence. The first client, connected to {@code <path>/C1}, must receive {@code "c1"}. Once the second client
 * connects to {@code <path>/C2}, both clients must receive {@code "c2"}.
 */
public class BroadcastOnOpenTest {

    @RegisterExtension
    public static final QuarkusUnitTest test = new QuarkusUnitTest()
            .withApplicationRoot(root -> {
                root.addClasses(Lo.class, LoBlocking.class, LoMultiProduce.class);
            });

    @TestHTTPResource("lo")
    URI loUri;

    @TestHTTPResource("lo-blocking")
    URI loBlockingUri;

    @TestHTTPResource("lo-multi-produce")
    URI loMultiProduceUri;

    @Inject
    Vertx vertx;

    @Test
    public void testLo() throws Exception {
        assertBroadcast(loUri);
    }

    @Test
    public void testLoBlocking() throws Exception {
        assertBroadcast(loBlockingUri);
    }

    // Exercises the LoMultiProduce endpoint (served at "lo-multi-produce").
    @Test
    public void testLoMultiBidi() throws Exception {
        assertBroadcast(loMultiProduceUri);
    }

    /**
     * Connects two WebSocket clients to {@code testUri} one after the other and asserts the open-broadcast behavior.
     * <p>
     * The first client must see exactly one message, {@code "c1"}. After the second client connects, exactly two
     * messages must have been received in total, both {@code "c2"}: one per client. Both clients are closed
     * afterwards, even if an assertion or the first close fails.
     *
     * @param testUri base endpoint URI; {@code "/C1"} and {@code "/C2"} are appended to its path
     * @throws AssertionError if a client fails to connect or the expected messages do not arrive within 5 seconds
     * @throws Exception if waiting is interrupted or closing a client fails
     */
    public void assertBroadcast(URI testUri) throws Exception {
        WebSocketClient client1 = vertx.createWebSocketClient();
        WebSocketClient client2 = vertx.createWebSocketClient();
        try {
            CountDownLatch c1MessageLatch = new CountDownLatch(1);
            CountDownLatch c2MessageLatch = new CountDownLatch(2);
            List<String> messages = new CopyOnWriteArrayList<>();
            // The connect callbacks run asynchronously (typically on a Vert.x event-loop thread), so an exception
            // thrown there would not reach the test thread blocked in await(). Record failures instead and
            // rethrow them below, keeping the original cause.
            List<Throwable> connectFailures = new CopyOnWriteArrayList<>();
            client1
                    .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/C1")
                    .onComplete(r -> {
                        if (r.succeeded()) {
                            WebSocket ws = r.result();
                            // Installed inside the callback so no open-broadcast message is missed.
                            ws.textMessageHandler(msg -> {
                                messages.add(msg + ":client1");
                                if (msg.equals("c1")) {
                                    c1MessageLatch.countDown();
                                } else if (msg.equals("c2")) {
                                    // onOpen callback from the second client
                                    c2MessageLatch.countDown();
                                }
                            });
                            // Trigger emission for LoMultiProduce
                            ws.writeTextMessage("foo");
                        } else {
                            recordConnectFailure(connectFailures, r.cause(), c1MessageLatch);
                        }
                    });
            boolean c1Received = c1MessageLatch.await(5, TimeUnit.SECONDS);
            assertNoConnectFailure(connectFailures);
            assertTrue(c1Received, "Messages: " + messages);
            assertEquals(1, messages.size());
            assertEquals("c1:client1", messages.get(0));
            messages.clear();
            // Now connect the second client
            client2
                    .connect(testUri.getPort(), testUri.getHost(), testUri.getPath() + "/C2")
                    .onComplete(r -> {
                        if (r.succeeded()) {
                            WebSocket ws = r.result();
                            ws.textMessageHandler(msg -> {
                                messages.add(msg + ":client2");
                                c2MessageLatch.countDown();
                            });
                            // Trigger emission for LoMultiProduce
                            ws.writeTextMessage("foo");
                        } else {
                            recordConnectFailure(connectFailures, r.cause(), c2MessageLatch);
                        }
                    });
            boolean c2Received = c2MessageLatch.await(5, TimeUnit.SECONDS);
            assertNoConnectFailure(connectFailures);
            assertTrue(c2Received, "Messages: " + messages);
            // onOpen should be broadcasted to both clients
            assertEquals(2, messages.size(), "Messages: " + messages);
            assertEquals("c2", messages.get(0).substring(0, 2));
            assertEquals("c2", messages.get(1).substring(0, 2));
        } finally {
            // Close both clients even if closing the first one fails.
            try {
                client1.close().toCompletionStage().toCompletableFuture().get();
            } finally {
                client2.close().toCompletionStage().toCompletableFuture().get();
            }
        }
    }

    // Stores a connect failure and fully releases the latch so the test thread stops waiting immediately.
    private static void recordConnectFailure(List<Throwable> failures, Throwable cause, CountDownLatch latch) {
        failures.add(cause);
        while (latch.getCount() > 0) {
            latch.countDown();
        }
    }

    // Fails the test with the first recorded connect failure as the cause, if there is one.
    private static void assertNoConnectFailure(List<Throwable> failures) {
        if (!failures.isEmpty()) {
            throw new AssertionError("WebSocket client failed to connect", failures.get(0));
        }
    }

}
